/**
 * \file dnn/src/rocm/reduce_helper/column.hipinl
 *
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#pragma once

#include "src/rocm/reduce_helper.h.hip"
#include <limits>

namespace megdnn {
namespace rocm {
namespace reduce_intl {

/*!
 * \brief reduce every row of an A x B matrix to a single value
 *
 * For each a < A the kernel combines op.read(a * B + c) for all c in [0, B)
 * with Op::apply and hands the result to op.write(a, result).
 *
 * each block has (1 << block_size_log2) threads and process fixed number of
 * rows; each row is processed by (1 << nr_thread_per_row_log2) threads.
 *
 * Dynamic shared memory holds one array of partial results per row handled by
 * the block; the array of row r starts r * sm_width_byte bytes into the
 * buffer. The launcher is expected to provide a padding of
 * max_nr_threads_per_row/2 elements after these arrays.
 *
 * \tparam block_size_log2 log2 of the number of threads in a block
 * \tparam max_nr_threads_per_row compile-time upper bound of the threads per
 *      row; it fixes the number of (unrolled) reduction steps
 * \tparam Op reduction operator; the kernel uses Op::wtype, op.read(idx),
 *      the static Op::apply(x, y) and op.write(idx, value)
 * \tparam warp_size reduction steps whose stride is below this value are
 *      executed without a block-wide barrier
 * \param A number of rows, i.e. number of output values
 * \param B number of elements in each row; element 0 of a row is always
 *      read, so B is presumably required to be >= 1 (TODO confirm with
 *      callers)
 * \param nr_thread_per_row_log2 log2 of the number of threads per row
 * \param sm_width_byte distance in bytes between the shared memory arrays of
 *      two consecutive rows
 */
template<int block_size_log2, int max_nr_threads_per_row,
    class Op, int warp_size>
__global__ void kern_column(Op op,
        uint32_t A, uint32_t B, uint32_t nr_thread_per_row_log2,
        uint32_t sm_width_byte) {
    typedef typename Op::wtype wtype;
    // shared mem: matrix(nr_row_per_block, nr_thread_per_row)
    HIP_DYNAMIC_SHARED( uint8_t, sub_block_raw)

    uint32_t nr_row_per_block =
                1 << (block_size_log2 - nr_thread_per_row_log2),
             nr_thread_per_row = 1 << nr_thread_per_row_log2,
             // index of the row inside this block
             row_num = hipThreadIdx_x >> nr_thread_per_row_log2,
             // tid in current row
             tid = hipThreadIdx_x - (row_num << nr_thread_per_row_log2),
             // index of the row inside the whole matrix
             a = hipBlockIdx_x * nr_row_per_block + row_num;

    // partial results of the current row; volatile so that every access in
    // the reduction below really goes to shared memory
    volatile wtype* row = (wtype*)(sub_block_raw + row_num * sm_width_byte);

    // each thread folds the elements tid, tid + nr_thread_per_row, ... of its
    // row and stores the partial result in row[tid]
    {
        // threads with a >= A are clamped to the last valid row, so they
        // execute the same code path (including the barriers below); their
        // result is discarded by the guard before op.write()
        uint32_t base = min(a, A - 1) * B;
        wtype csum = op.read(base + tid);
        // unsigned index: matches the type of B and of the stride, avoiding
        // a signed/unsigned comparison
        for (uint32_t c = tid + nr_thread_per_row; c < B;
                c += nr_thread_per_row) {
            csum = Op::apply(csum, op.read(base + c));
        }
        row[tid] = csum;
    }

    // tree reduction of the partial results: in each step the first i
    // threads of a row fold in the value that lies i slots further
#pragma unroll
    for (uint32_t i = max_nr_threads_per_row / 2; i; i >>= 1) {
        // steps wider than the actual number of threads per row are skipped
        bool cond = nr_thread_per_row >= i * 2 && tid < i;
        if (i >= warp_size) {
            // the condition only depends on template arguments and the loop
            // counter, so the barrier is reached by all threads of the block
            __syncthreads();
        }
        // NOTE(review): steps with i < warp_size run without any barrier;
        // this presumably relies on the threads taking part in such a step
        // belonging to one warp that executes in lockstep, together with the
        // volatile accesses -- confirm for the targeted devices
        if (cond) {
            wtype v0 = row[tid];
            wtype v1 = Op::apply(v0, row[tid + i]);
            row[tid] = v1;
        }
    }

    // the first thread of each valid row publishes the final result
    if (a < A && !tid) {
        op.write(a, row[0]);
    }
}

/*!
 * \brief choose the launch configuration for kern_column and start it
 *
 * Picks the number of threads that cooperate on one row, derives the grid
 * size and the dynamic shared memory size from it, selects the kern_column
 * instantiation with the smallest sufficient compile-time thread bound and
 * launches it on \p stream.
 *
 * \tparam max_nr_threads_per_row max number of threads to reduce each row;
 *      NOTE(review): block_size / nr_thread_per_row must not be zero, so this
 *      presumably has to be <= (1 << block_size_log2) -- confirm
 * \tparam block_size_log2 log2 of threads in a block
 * \tparam warp_size size of warp on the device
 * \param A number of rows; nothing is launched if it is zero
 * \param B number of elements in each row; NOTE(review): B == 0 is not
 *      handled here, kern_column reads the first element of every row
 *      unconditionally -- callers presumably guarantee B >= 1
 * \param stream stream on which the kernel is launched
 * \param op reduction operator, passed by value to the kernel
 */
template<class Op,
    uint32_t max_nr_threads_per_row, uint32_t block_size_log2,
    uint32_t warp_size>
void _do_run_column(uint32_t A, uint32_t B, hipStream_t stream,
        const Op &op) {
    typedef typename Op::wtype wtype;

    if (!A) {
        // no rows means no output; without this guard nr_blk would be 0 and
        // a grid of zero blocks is not a valid launch configuration
        return;
    }

    const uint32_t block_size = 1 << block_size_log2;
    uint32_t nr_thread_per_row = 1, nr_thread_per_row_log2 = 0;

    while (nr_thread_per_row < max_nr_threads_per_row &&
            nr_thread_per_row * 2 <= B) {
        ++ nr_thread_per_row_log2;
        nr_thread_per_row *= 2;
    }
    // now: nr_thread_per_row is a power of two and, for B >= 1,
    // nr_thread_per_row <= B; B < nr_thread_per_row * 2 holds as well unless
    // the loop was stopped by the max_nr_threads_per_row limit

    if (B <= max_nr_threads_per_row * 4) {
        // find nr_thread_per_row with minimal wasted threads: cost is the
        // number of threads without an element in the last pass over a row;
        // candidates are warp_size, warp_size * 2, ... up to the current
        // value, and the smallest candidate wins a tie
        uint32_t min_cost = std::numeric_limits<uint32_t>::max(),
                 min_cost_th = 0;
        for (uint32_t i = warp_size; i <= nr_thread_per_row; i *= 2) {
            uint32_t cost = (i - B % i) % i;
            if (cost < min_cost) {
                min_cost = cost;
                min_cost_th = i;
            }
        }
        // min_cost_th stays 0 if nr_thread_per_row < warp_size, i.e. if
        // there was no candidate; the current value is kept in that case
        if (min_cost_th) {
            nr_thread_per_row = min_cost_th;
            // min_cost_th never exceeds the old value, so the matching log2
            // is found by counting down
            while ((1u << nr_thread_per_row_log2) != nr_thread_per_row)
                -- nr_thread_per_row_log2;
        }
    }

    uint32_t nr_row_per_block = block_size / nr_thread_per_row,
             nr_blk = DIVUP(A, nr_row_per_block),
             // width of one shared memory row in 32-bit words
             sm_width_word32 = DIVUP(nr_thread_per_row * sizeof(wtype), 4ul);

    // gcd(sm_width_word32, BANKS) should be 1 to avoid bank conflicts, which
    // is the case iff sm_width_word32 is odd; so an even width is incremented
    // NOTE(review): the row stride is therefore an odd multiple of 4 bytes,
    // so rows are only guaranteed to be 4-byte aligned; presumably wtype
    // never needs a stricter alignment -- confirm
    sm_width_word32 += !(sm_width_word32 % 2);
    uint32_t sm_width_byte = sm_width_word32 * 4,
             // all rows of a block plus the padding documented at kern_column
             sm_size = nr_row_per_block * sm_width_byte +
                 sizeof(wtype) * max_nr_threads_per_row / 2;

    // pick the instantiation with the smallest thread bound that still
    // covers nr_thread_per_row, so that fewer reduction steps are unrolled
    void (*kptr)(Op op,
        uint32_t A, uint32_t B, uint32_t nr_thread_per_row_log2,
        uint32_t sm_width_byte);
    if (nr_thread_per_row <= max_nr_threads_per_row / 4) {
        kptr = kern_column<block_size_log2, max_nr_threads_per_row / 4,
             Op, warp_size>;
    } else if (nr_thread_per_row <= max_nr_threads_per_row / 2) {
        kptr = kern_column<block_size_log2, max_nr_threads_per_row / 2,
             Op, warp_size>;
    } else {
        kptr = kern_column<block_size_log2, max_nr_threads_per_row,
             Op, warp_size>;
    }
    hipLaunchKernelGGL((kptr), dim3(nr_blk), dim3(block_size), sm_size, stream, 
            op, A, B, nr_thread_per_row_log2, sm_width_byte);
    after_kernel_launch();
}


// a struct with a static member function is used instead of a function
// template so that default template arguments are available in C++-03
/*!
 * \brief start the hip kernel to reduce in column direction of a matrix
 *
 * run() forwards all of its arguments unchanged to _do_run_column,
 * instantiated with the template arguments of this struct.
 *
 * \tparam Op reduction operator type
 * \tparam max_nr_threads_per_row max number of threads to reduce each row
 * \tparam block_size_log2 log2 of threads in a block
 * \tparam warp_size size of warp on the device
 */
template<class Op,
    uint32_t max_nr_threads_per_row=64, uint32_t block_size_log2=7,
    uint32_t warp_size=32>
struct run_column {
    //! launch the reduction of an A x B matrix with \p op on \p stream
    static void run(
            uint32_t A, uint32_t B, hipStream_t stream,
            const Op &op) {
        _do_run_column<
                Op, max_nr_threads_per_row, block_size_log2, warp_size>(
                A, B, stream, op);
    }
};

} // namespace reduce_intl
} // namespace rocm
} // namespace megdnn

// vim: ft=cpp syntax=cpp.doxygen
